# 加载环境变量
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
import os
import json
import sqlite3

# Read the local .env file, which defines OPENAI_API_KEY.
_ = load_dotenv(find_dotenv())

# Both client settings come from the environment; an unset variable is passed as None.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)


def remove_none_values(data):
    """Return a copy of ``data`` with ``None`` entries stripped, recursively.

    Dict items whose value is ``None`` and list elements that are ``None``
    are dropped; the remaining values are cleaned the same way. Anything
    that is neither a dict nor a list is returned unchanged.
    """
    if isinstance(data, dict):
        kept_items = {}
        for key, value in data.items():
            if value is None:
                continue
            kept_items[key] = remove_none_values(value)
        return kept_items

    if isinstance(data, list):
        kept_elements = []
        for element in data:
            if element is None:
                continue
            kept_elements.append(remove_none_values(element))
        return kept_elements

    return data


def print_json(data):
    """Print ``data``; lists and dicts are rendered as indented JSON.

    Structured values are dumped with a 4-space indent and non-ASCII
    characters left unescaped. Any other value is printed as-is.
    """
    is_structured = isinstance(data, (list, dict))
    print(json.dumps(data, indent=4, ensure_ascii=False) if is_structured else data)


def serialize_json(data):
    """Convert ``data`` into plain JSON-style values with ``None`` entries removed.

    An object exposing ``model_dump_json`` is first dumped to a JSON string
    and parsed back. If the result is a list or dict it is passed through
    ``remove_none_values``; any other value is returned unchanged.
    """
    if hasattr(data, 'model_dump_json'):
        json_text = data.model_dump_json()
        data = json.loads(json_text)

    return remove_none_values(data) if isinstance(data, (list, dict)) else data


def get_sql_completion(params, model="gpt-3.5-turbo"):
    """Send the conversation to the chat model with the ``ask_database`` tool attached.

    The outgoing messages and the model's reply are both echoed to stdout.

    Args:
        params: Conversation messages; passed through ``serialize_json``
            before being sent.
        model: Name of the chat model to call.

    Returns:
        The first choice's message, as returned by ``serialize_json``.
    """
    print("=====GPT请求=====")
    messages = serialize_json(params)
    print_json(messages)

    # Tool definition adapted from the official OpenAI cookbook example:
    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb
    function_description = "Use this function to answer user questions about business. \
                                Output should be a fully formed SQL query."
    # The schema is interpolated so the model knows which tables and columns exist.
    query_description = f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                The query should only contain grammars supported by SQLite.
                                """
    ask_database_tool = {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": function_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": query_description,
                    }
                },
                "required": ["query"],
            }
        }
    }

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,
        tools=[ask_database_tool],
    )

    first_message = response.choices[0].message
    result = serialize_json(first_message)
    print("=====GPT回复=====")
    print_json(result)
    return result


# Schema of the orders table. The same text is interpolated into the tool
# description built in get_sql_completion and executed against SQLite below.
# product_id is declared TEXT: "STR" is not a type name SQLite recognises, so
# the column would get NUMERIC affinity and numeric-looking IDs such as '007'
# would be stored as integers.
database_schema_string = """
CREATE TABLE orders (
    id INT PRIMARY KEY NOT NULL, -- 主键，不允许为空
    customer_id INT NOT NULL, -- 客户ID，不允许为空
    product_id TEXT NOT NULL, -- 产品ID，不允许为空
    price DECIMAL(10,2) NOT NULL, -- 价格，不允许为空
    status INT NOT NULL, -- 订单状态，整数类型，不允许为空。0代表待支付，1代表已支付，2代表已退款
    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 创建时间，默认为当前时间
    pay_time TIMESTAMP -- 支付时间，可以为空
);
"""

# Open an in-memory database; its contents live only as long as this connection.
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()

# Create the orders table.
cursor.execute(database_schema_string)

# Five fixed mock orders:
# (id, customer_id, product_id, price, status, create_time, pay_time).
# pay_time is on the same day as create_time and later than it; unpaid
# orders (status 0) have no pay_time.
mock_data = [
    (1, 1001, 'TSHIRT_1', 50.00, 0, '2023-10-12 10:00:00', None),
    (2, 1001, 'TSHIRT_2', 75.50, 1, '2023-10-16 11:00:00', '2023-10-16 12:00:00'),
    (3, 1002, 'SHOES_X2', 25.25, 2, '2023-10-17 12:30:00', '2023-10-17 13:00:00'),
    (4, 1003, 'HAT_Z112', 60.75, 1, '2023-10-20 14:00:00', '2023-10-20 15:00:00'),
    (5, 1002, 'WATCH_X001', 90.00, 0, '2023-10-28 16:00:00', None),
]

# Insert all rows through one parameterized statement.
cursor.executemany('''
    INSERT INTO orders (id, customer_id, product_id, price, status, create_time, pay_time)
    VALUES (?, ?, ?, ?, ?, ?, ?)
    ''', mock_data)

# Commit the transaction.
conn.commit()


def ask_database(query):
    """Execute ``query`` on the module-level cursor and return all result rows.

    NOTE(review): the SQL is executed as given, with no validation; the
    visible caller passes model-generated text.
    """
    return cursor.execute(query).fetchall()


# Question to ask; swap in one of the alternatives below to try other queries.
prompt = "10月的销售额"
# prompt = "统计每月每件商品的销售额"
# prompt = "哪个用户消费最高？消费多少？"

messages = [
    # The table is called "orders" in the schema, so the prompt names it the same way.
    {"role": "system", "content": "基于 orders 表回答用户问题"},
    {"role": "user", "content": prompt}
]
response = get_sql_completion(messages)
messages.append(response)

# Keep asking the model until it stops requesting tool calls.
while response.get('tool_calls'):
    # Answer every tool call of this assistant message before querying the
    # model again: each tool_call_id needs its own tool message.
    for tool_call in response['tool_calls']:
        tool_name = tool_call['function']['name']
        if tool_name == "ask_database":
            arguments = tool_call['function']['arguments']
            args = json.loads(arguments)
            print("SQL:", args["query"])
            result = ask_database(args["query"])
            print("DB Records:", result)
            tool_output = str(result)
        else:
            # Still reply, so the call is answered and the loop cannot spin
            # forever on a tool this script does not implement.
            tool_output = f"Unknown tool: {tool_name}"

        messages.append({
            "tool_call_id": tool_call['id'],
            "role": "tool",
            "name": tool_name,
            "content": tool_output
        })

    response = get_sql_completion(messages)
    # Record the assistant reply so that any further tool calls it contains
    # are preceded by their assistant message in the history.
    messages.append(response)

# Printed once, after the model has produced a reply without tool calls.
print("====最终回复====")
print(response.get('content'))
